
[Metal][Performance]: Add split-K for quantized matmul (small M)#3120

Open
Ziqiao-git wants to merge 1 commit into ml-explore:main from Ziqiao-git:qmm-splitk-small-m

Conversation


@Ziqiao-git Ziqiao-git commented Feb 12, 2026

Proposed changes

In issue #3086, it was observed that the quantized qmm kernel severely underutilizes the GPU for small M (e.g., M=12-32). For example, a configuration of D=2560 and M=12 yields only 80 threadgroups (assuming BM=BN=32), which is insufficient to saturate the GPU grid.

This PR introduces a split-K variant (qmm_t_splitk) that partitions the K dimension across multiple threadgroups. This safely improves GPU occupancy and execution speed for small-batch inference scenarios, while falling back to the standard kernel for larger batches to prevent any performance regression.

What changed

  • Added a split-K variant of the quantized matrix multiplication kernel (qmm_t_splitk) in the Metal backend, conceptually similar to the existing fp16 steel_gemm_splitk.
  • Updated the dispatch logic in quantized.cpp to dynamically calculate the split factor, targeting ~512 threadgroups for optimal occupancy.
  • Added a fallback mechanism that automatically routes to the regular qmm kernel when split_k <= 1 (e.g., for large M).
  • Verified performance gains on Apple M3 Max (4-bit, group_size=64):
    • D=2560, M=12: 0.079ms -> 0.055ms (~30% faster)
    • D=4096, M=16: 0.155ms -> 0.117ms (~25% faster)
    • No regressions observed for large M configurations.
  • Verified correctness: Existing tests cover the new code path (27 tests, 1639 subtests pass).

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@Ziqiao-git
Author

As a quick note on why these specific dimensions matter: The performance bottleneck for small $M$ sizes ($12 \sim 32$) directly impacts the verification step in Speculative Decoding.
By fixing the GPU underutilization here, we significantly speed up the time it takes to evaluate draft tokens. Given how important speculative decoding is for pushing the limits of inference speed on edge devices, this change should provide a meaningful boost to overall generation latency, making the backend more robust for future speculative decoding implementations.

@angeloskath
Copy link
Member

Thanks, that is great! I'll take a look asap.

@Ziqiao-git Ziqiao-git changed the title metal: Add split-K for quantized matmul (small M) [Metal][Performance]: Add split-K for quantized matmul (small M) Feb 13, 2026
Member

@angeloskath angeloskath left a comment

This is great but unfortunately unfinished.

The fp quantizations are not implemented (they should be trivial to add based on qmm_t_splitk_impl), and the qmv split-K is not used by anything, so it can be removed for starters, or you can finish the implementation if you want.

Finally, you don't need both a qmm_t_splitk_impl and a qmv_splitk_impl; the point of qmv_impl, qmv_fast_impl, and qmm_t_impl is that one can adjust the input matrix offsets and call the implementation. See qvm_splitk for an example.

@Ziqiao-git Ziqiao-git force-pushed the qmm-splitk-small-m branch 2 times, most recently from e7c8390 to 884b86f on February 24, 2026 01:27
@Ziqiao-git
Author

Addressed all feedback:

  1. Removed unused qmv_split_k code (impl, kernel, macro, and dispatch function)
  2. Removed qmm_t_splitk_impl: affine_qmm_t_splitk now pre-offsets pointers and calls qmm_t_impl directly (added a K_eff parameter for the loop bound, following the qvm_splitk pattern)
  3. Added fp quantization support: fp_qmm_t_impl also takes K_eff, new fp_qmm_t_splitk kernel wrapper, and instantiation macros in fp_quantized.metal

Build passes, pre-commit is clean, and the benchmark confirms split-K works correctly for both the affine and fp paths.

Benchmark Results (Device: applegpu_g15s, Memory: 52 GB)

qmv_batch_limit(D=4096, O=4096) = 12

D=4096 mode=mxfp8 (bits=8, group_size=32)

| M    | fp16    | quant   | ratio | fp16 kernel  | quant kernel |
|------|---------|---------|-------|--------------|--------------|
| 1    | 0.475ms | 0.036ms | 0.08x | gemv         | qmv          |
| 2    | 0.519ms | 0.061ms | 0.12x | split-K      | qmv          |
| 4    | 0.533ms | 0.111ms | 0.21x | split-K      | qmv          |
| 8    | 0.516ms | 0.212ms | 0.41x | split-K      | qmv          |
| 10   | 0.520ms | 0.262ms | 0.50x | split-K      | qmv          |
| 12   | 0.518ms | 0.121ms | 0.23x | split-K      | qmm_splitk   |
| 14   | 0.521ms | 0.122ms | 0.23x | split-K      | qmm_splitk   |
| 16   | 0.515ms | 0.122ms | 0.24x | split-K      | qmm_splitk   |
| 20   | 0.584ms | 0.121ms | 0.21x | split-K      | qmm_splitk   |
| 32   | 0.526ms | 0.125ms | 0.24x | split-K      | qmm_splitk   |
| 2048 | 6.318ms | 6.646ms | 1.05x | regular GEMM | qmm          |

(Tested across affine, mxfp8, and mxfp4; truncated for brevity, but all show similarly smooth transitions.)

Let me know if there is anything I missed.

Member

@jagrit06 jagrit06 left a comment

Requesting a few small changes, thanks for putting this together!

Comment on lines +788 to +802
```cpp
// Choose split_k to target ~512 threadgroups
int bm = 32, bn = 32;
int n_tiles = (N + bn - 1) / bn;
int m_tiles = (M + bm - 1) / bm;
int current_tgs = n_tiles * m_tiles;
int split_k = std::max(1, 512 / current_tgs);

// Ensure K divides evenly by split_k * group_size
while (split_k > 1 && (K % (split_k * group_size) != 0)) {
  split_k--;
}
if (split_k <= 1) {
  return qmm(
      x, w, scales, biases, out, true, group_size, bits, M, N, K, d, s, mode);
}
```
Member

@jagrit06 jagrit06 Feb 24, 2026

Could you help me understand why this approach was chosen?
I can make some sense of 512 threadgroups being the target (if M or N are larger, we go to the regular matmul; if K is larger still, we just loop more inside the kernel), but I would like to know how you ended up at that number.

Also, we already check that K is always divisible by group_size, so wouldn't the looping step then just come down to split_k = std::min(split_k, K / group_size)?

Author

For the 512 target, it was mostly an empirical choice trying to balance GPU occupancy with atomic reduction overhead. My mental math was that on the larger Apple Silicon chips (like the 76-core M2 Ultra), we generally need a handful of active threadgroups per core to effectively hide memory latency (roughly $76 \times 6 \approx 456$). So aiming for ~512 felt like a safe baseline to ensure good saturation even on the high-end chips. On the flip side, I wanted to avoid pushing the split factor too high, to keep the atomic add contention in check during the final accumulation, as over-splitting the K dimension can quickly offset the occupancy gains.

Ideally, this target should be a dynamic function of the active GPU core count. However, since dynamically querying the exact core count in Metal isn't straightforward, tuning for the high-end ceiling acts as a safe catch-all. The hardware scheduler on smaller chips (like the base M1/M2) handles the oversubscription of 512 threadgroups gracefully without much penalty, whereas under-subscribing an Ultra would be catastrophic for latency hiding. If MLX has a preferred internal way to dynamically scale split-K factors based on device tiers, I'm completely open to refactoring this!

Author

@Ziqiao-git Ziqiao-git Feb 28, 2026

I originally avoided just using split_k = std::min(split_k, K / group_size) because we need to guarantee quantization group alignment, not just an upper bound.

If we strictly use std::min, it might result in a split_k that doesn't evenly divide K into multiples of group_size. For example, if K=1024, group_size=64, and the calculated split_k is 12, the chunk size 1024 / 12 does not align with 64 (the quantization group size). If a threadgroup's chunk starts unaligned with group_size, it will read misaligned scale and bias values, leading to incorrect numerical results.

The while loop ensures we step down until (K / split_k) % group_size == 0 so that every threadgroup boundary perfectly aligns with the quantization blocks. Let me know if you think there is a more elegant way to enforce this modulo alignment mathematically instead of the loop!

Member

Good point

I think the way around this would be to support an unaligned K_eff for the last partition (that logic doesn't exist in the QMMs but does in the regular MM code). We certainly don't want a K that divides into 31 quantized groups to miss the split-K variant just because 31 is prime and we require the number of splits to perfectly divide K / group_size.

That said, we can merge this for now and do that in a follow up PR

Add a split-K variant for quantized matrix multiplication that
partitions the K dimension across threadgroups when GPU occupancy
is low (small M).

- Reuse qmm_t_impl with a K_eff parameter for the loop bound,
  pre-offset pointers in the splitk wrapper (following qvm_splitk pattern)
- Remove unused qmv_split_k code
- Add fp quantization support (fp_qmm_t_splitk)
- Dynamic split_k selection targeting ~512 threadgroups
- Fallback to regular qmm when split_k <= 1
Comment on lines +789 to +793
```cpp
int bm = 32, bn = 32;
int n_tiles = (N + bn - 1) / bn;
int m_tiles = (M + bm - 1) / bm;
int current_tgs = n_tiles * m_tiles;
int split_k = std::max(1, 512 / current_tgs);
```
Member

For now, just so we are safe, could you add a check here to make sure that K is large enough (relative to M and N) to warrant going through the loop of splits?
We naturally short-circuit if n_tiles * m_tiles >= 512, so large M and N are covered; it would be good to similarly short-circuit when K isn't large compared to M and N.

Until we have a fix for the loop, I would like to avoid, say, a 64x64x128 matmul going through a hundred iterations of the alignment loop where it could have short-circuited earlier.

After that, we should be ready to merge!
